﻿using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Runtime.Serialization;
using System.Runtime.Serialization.Formatters.Binary;
using FyndSharp.Utilities.IO;
using FyndSharp.Utilities.Serialization;
using gov.va.med.VBECS.Communication.Common;

namespace gov.va.med.VBECS.Communication.Protocols
{
    /// <summary>
    /// Length-prefixed framing protocol: each message is serialized with
    /// <see cref="BinaryFormatter"/> and sent as a 4-byte length header
    /// (written/read by <c>BaseTypeSerializer.Int32</c>) followed by the serialized bytes.
    /// Received bytes are buffered until one or more complete frames are available.
    /// </summary>
    /// <remarks>
    /// The receive buffer is per-instance state and is not synchronized.
    /// NOTE(review): BinaryFormatter deserialization of bytes received from a remote peer is
    /// inherently unsafe (it can instantiate arbitrary serializable types). The
    /// <see cref="DeserializationAppDomainBinder"/> only restricts resolution to assemblies
    /// already loaded in the AppDomain; it is not a security boundary. Confirm peers are trusted.
    /// </remarks>
    internal class BinarySerializationProtocol : IProtocol
    {
        /// <summary>
        /// Maximum length of a message (128 megabytes).
        /// </summary>
        private const int MaxMessageLength = 128 * 1024 * 1024;

        /// <summary>
        /// Size in bytes of the length header that precedes every message.
        /// </summary>
        private const int HeaderLength = 4;

        /// <summary>
        /// Collects received bytes until complete messages can be built from them.
        /// Always created with the parameterless constructor, so <see cref="MemoryStream.GetBuffer"/> is allowed.
        /// </summary>
        private MemoryStream _receiveMemoryStream = new MemoryStream();

        /// <summary>
        /// Serializes a message and prefixes it with its 4-byte length header.
        /// </summary>
        /// <param name="theMsg">The message to send.</param>
        /// <returns>The length header followed by the serialized message bytes.</returns>
        /// <exception cref="CommunicationException">The serialized message is longer than <see cref="MaxMessageLength"/>.</exception>
        public byte[] GetBytes(IMessage theMsg)
        {
            var msgBytes = SerializeMessage(theMsg);

            var msgLength = msgBytes.Length;
            if (msgLength > MaxMessageLength)
            {
                throw new CommunicationException("Message is too big (" + msgLength + " bytes). Max allowed length is " + MaxMessageLength + " bytes.");
            }

            // Header first, then payload, in a single right-sized buffer.
            using (var resultStream = new MemoryStream(HeaderLength + msgLength))
            {
                BaseTypeSerializer.Int32.Write(msgLength, resultStream);
                resultStream.Write(msgBytes, 0, msgLength);
                return resultStream.ToArray();
            }
        }

        /// <summary>
        /// Appends received bytes to the internal buffer and returns every message that is now complete.
        /// Incomplete trailing data stays buffered for the next call.
        /// </summary>
        /// <param name="theBytes">Bytes just received from the remote side.</param>
        /// <returns>The messages completed by these bytes (possibly empty).</returns>
        /// <exception cref="CommunicationException">A length header is negative or exceeds <see cref="MaxMessageLength"/>.</exception>
        public IEnumerable<IMessage> BuildMessages(byte[] theBytes)
        {
            // read_single_message moves the position (e.g. back to 0 when fewer than
            // HeaderLength bytes are buffered), so always append at the end; writing at
            // the current position would overwrite partially received header bytes.
            _receiveMemoryStream.Seek(0, SeekOrigin.End);
            _receiveMemoryStream.Write(theBytes, 0, theBytes.Length);

            var msgList = new List<IMessage>();
            while (read_single_message(msgList)) { }
            return msgList;
        }

        /// <summary>
        /// Discards any buffered, not-yet-complete received data.
        /// </summary>
        public void Reset()
        {
            if (_receiveMemoryStream.Length > 0)
            {
                _receiveMemoryStream = new MemoryStream();
            }
        }

        /// <summary>
        /// Serializes a message with <see cref="BinaryFormatter"/>.
        /// </summary>
        /// <param name="message">The message to serialize.</param>
        /// <returns>The serialized bytes (without length header).</returns>
        protected virtual byte[] SerializeMessage(IMessage message)
        {
            using (var mem = new MemoryStream())
            {
                new BinaryFormatter().Serialize(mem, message);
                return mem.ToArray();
            }
        }

        /// <summary>
        /// Deserializes a message body produced by <see cref="SerializeMessage"/>.
        /// </summary>
        /// <param name="bytes">Serialized message bytes (without length header).</param>
        /// <returns>The deserialized message.</returns>
        /// <remarks>
        /// NOTE(review): this runs BinaryFormatter on data received from the network; see the class remarks.
        /// </remarks>
        protected virtual IMessage DeserializeMessage(byte[] bytes)
        {
            using (var mem = new MemoryStream(bytes))
            {
                // Simple assembly style plus the AppDomain binder lets types resolve by
                // simple assembly name, ignoring version/culture/public key differences.
                var theBinaryFormatter = new BinaryFormatter
                {
                    AssemblyFormat = System.Runtime.Serialization.Formatters.FormatterAssemblyStyle.Simple,
                    Binder = new DeserializationAppDomainBinder()
                };

                return (IMessage)theBinaryFormatter.Deserialize(mem);
            }
        }

        /// <summary>
        /// Tries to consume one frame from the front of the receive buffer.
        /// </summary>
        /// <param name="messages">Receives the deserialized message if a non-empty frame was complete.</param>
        /// <returns>
        /// <c>true</c> if a frame was consumed and enough bytes remain to hold another header;
        /// <c>false</c> to wait for more bytes.
        /// </returns>
        /// <exception cref="CommunicationException">The length header is negative or exceeds <see cref="MaxMessageLength"/>.</exception>
        private bool read_single_message(ICollection<IMessage> messages)
        {
            _receiveMemoryStream.Position = 0;

            // Not even the length header has arrived yet.
            if (_receiveMemoryStream.Length < HeaderLength)
            {
                return false;
            }

            var msgLength = BaseTypeSerializer.Int32.Read(_receiveMemoryStream);
            if (msgLength > MaxMessageLength)
            {
                throw new CommunicationException("Message is too big (" + msgLength + " bytes). Max allowed length is " + MaxMessageLength + " bytes.");
            }

            // A negative length can only come from a corrupt or hostile stream.
            if (msgLength < 0)
            {
                throw new CommunicationException("Invalid message length (" + msgLength + " bytes).");
            }

            // Cannot overflow: msgLength is within [0, MaxMessageLength].
            var frameLength = HeaderLength + msgLength;

            // The message body has not been fully received yet.
            if (_receiveMemoryStream.Length < frameLength)
            {
                return false;
            }

            // A zero-length frame carries no message (it should not occur); just drop its header.
            if (msgLength > 0)
            {
                var theMsgBytes = StreamHelper.ReadBytes(_receiveMemoryStream, msgLength);
                messages.Add(DeserializeMessage(theMsgBytes));
            }

            var remainingCount = discard_frame(frameLength);

            // Only try again if another header could be present.
            return remainingCount >= HeaderLength;
        }

        /// <summary>
        /// Replaces the receive buffer with a new stream holding only the bytes that follow
        /// the first <paramref name="frameLength"/> bytes; the new stream is positioned at its end.
        /// </summary>
        /// <param name="frameLength">Number of leading bytes to drop; must not exceed the buffered length.</param>
        /// <returns>The number of bytes kept in the buffer.</returns>
        private int discard_frame(int frameLength)
        {
            var remainingCount = (int)(_receiveMemoryStream.Length - frameLength);
            var remaining = new MemoryStream();
            if (remainingCount > 0)
            {
                // GetBuffer avoids copying the whole buffer; only bytes below Length are read.
                remaining.Write(_receiveMemoryStream.GetBuffer(), frameLength, remainingCount);
            }

            _receiveMemoryStream = remaining;
            return remainingCount;
        }

        /// <summary>
        /// Resolves serialized types against assemblies already loaded in the current AppDomain,
        /// matching on simple assembly name only (version, culture and key are ignored).
        /// </summary>
        protected sealed class DeserializationAppDomainBinder : SerializationBinder
        {
            /// <summary>
            /// Finds <paramref name="typeName"/> in a loaded assembly whose simple name matches
            /// <paramref name="assemblyName"/>.
            /// </summary>
            /// <returns>The first matching type, or <c>null</c> when no loaded assembly provides it
            /// (resolution is then left to the formatter).</returns>
            public override Type BindToType(string assemblyName, string typeName)
            {
                var toAssemblyName = assemblyName.Split(',')[0];

                // Several loaded assemblies may share a simple name (e.g. different versions);
                // skip the ones that do not contain the type instead of returning null early.
                return (from assembly in AppDomain.CurrentDomain.GetAssemblies()
                        where assembly.FullName.Split(',')[0] == toAssemblyName
                        let type = assembly.GetType(typeName)
                        where type != null
                        select type).FirstOrDefault();
            }
        }
    }
}
